import pandas as p
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from collections import OrderedDict

# Load Goodman-Bacon decomposition data
def readData(file='Data/baconData.csv'):
    """Load the Goodman-Bacon decomposition CSV and append group-size columns.

    Parameters
    ----------
    file : str, optional
        Path of the CSV to read. Defaults to ``'Data/baconData.csv'``, the
        path this loader previously hard-coded.

    Returns
    -------
    pandas.DataFrame
        The CSV contents after ``appendGroupSizes`` has added its
        ``<stub>GroupSize`` columns. The frame is also printed to stdout.
    """
    data = p.read_csv(file)
    data = appendGroupSizes(data)
    print(data)
    return data


def getNumWithinGroup(data, treatment, control, stub, includeCO = True):
    """Count distinct states whose ``<stub>gp`` label matches either group.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``'<stub>gp'`` column and a ``'state'`` column.
    treatment, control
        Group identifiers. Each is converted with ``str()`` and compared
        against the ``gp`` column.
    stub : str
        Column prefix, e.g. ``'bacReg'``.
    includeCO : bool, optional
        When False, Colorado is left out of the count if it appears.

    Returns
    -------
    int
        Number of distinct states among the matching rows, or 0 when
        ``treatment`` is NaN.
    """
    # NaN is the only value that compares unequal to itself.
    if treatment != treatment:
        return 0

    groupLabels = data['{}gp'.format(stub)]
    inEitherGroup = (groupLabels == str(treatment)) | (groupLabels == str(control))
    matchedStates = set(data.loc[inEitherGroup, 'state'])

    dropColorado = (not includeCO) and ('Colorado' in matchedStates)
    return len(matchedStates) - 1 if dropColorado else len(matchedStates)

# Add group size to dataframe for plotting purposes
def appendGroupSizes(data):
    """Add one ``<stub>GroupSize`` column per decomposition stub to *data*.

    For every row and each stub in ``bacReg``, ``bacVMT`` and ``bacGas``, the
    row's treatment (``<stub>T``) and control (``<stub>C``) identifiers are
    passed to ``getNumWithinGroup``. Colorado is excluded from the count for
    ``bacReg`` and included for the other stubs.

    *data* is modified in place (the new columns are added to it) and is
    also returned.
    """
    stubs = ['bacReg', 'bacVMT', 'bacGas']

    # getNumWithinGroup scans the whole frame, so calling it once per row is
    # quadratic. Its result depends only on the stub, on str() of both group
    # identifiers, and on whether the treatment is NaN. Memoise on exactly
    # that key so each repeated (treatment, control) pair is computed once.
    cache = {}

    rows = []
    for _, row in tqdm(data.iterrows(), desc='Appending group size.'):
        sizes = []
        for stub in stubs:
            treatment = row['{}T'.format(stub)]
            control = row['{}C'.format(stub)]
            key = (stub, str(treatment), str(control), bool(treatment != treatment))
            if key not in cache:
                # Only the count that is actually stored gets computed.
                cache[key] = getNumWithinGroup(
                    data, treatment, control, stub, includeCO=(stub != 'bacReg'))
            sizes.append(cache[key])
        rows.append(sizes)

    # The reshape keeps the array 2-D (n_rows, n_stubs) even when data is
    # empty, so the column slicing below cannot raise IndexError.
    sizeArray = np.array(rows, dtype=int).reshape(len(rows), len(stubs))
    for i, stub in enumerate(stubs):
        data['{}GroupSize'.format(stub)] = sizeArray[:, i]

    return data

# Plot Goodman-Bacon decomposition coefficients
def plotCoefs(data, coefStub, x=6, title=None, show=False, saveFile=None):
    """Scatter each 2x2 estimate against its weight, coloured by comparison type.

    Parameters
    ----------
    data : pandas.DataFrame
        Reads the ``<coefStub>B`` (estimate), ``<coefStub>S`` (weight),
        ``<coefStub>cgroup`` (comparison type) and ``<coefStub>GroupSize``
        columns. The caller's frame is not modified.
    coefStub : str
        Column prefix, e.g. ``'bacReg'``.
    x : float, optional
        Figure width in inches. The height is ``x`` divided by the golden ratio.
    title : str, optional
        Plot title.
    show : bool, optional
        If True, call ``plt.show()`` after plotting.
    saveFile : str, optional
        If given, save the figure to ``Plots/Bacon/<saveFile>.png``.

    Returns
    -------
    pandas.DataFrame
        The rows with a non-NaN estimate, with weights normalised to sum to
        one, plus ``'Number of States'`` and ``'Coefficient Type'`` columns.
    """
    coefCol = coefStub + 'B'
    weightCol = coefStub + 'S'
    typeCol = coefStub + 'cgroup'
    sizeCol = coefStub + 'GroupSize'

    # Work on a copy so that normalising the weights does not overwrite the
    # caller's frame.
    data = data.copy()
    data[weightCol] = data[weightCol] / np.nansum(data[weightCol])

    coefData = data.loc[data[coefCol].notna(), :].copy()

    # Bump zero group sizes to one. This used to be a frame-wide
    # replace(0, 1), which also turned exact-zero estimates and weights into 1
    # and made the points disagree with the average line below. That line is
    # computed from the unreplaced data.
    sizeCols = [col for col in coefData.columns if str(col).endswith('GroupSize')]
    if sizeCols:
        coefData[sizeCols] = coefData[sizeCols].replace(0, 1)

    # Weighted average of the 2x2 estimates (NaNs ignored).
    coef = np.nansum(data[coefCol] * data[weightCol])

    coefData['Number of States'] = data[sizeCol]
    coefData['Coefficient Type'] = data[typeCol]

    goldenRatio = (1 + np.sqrt(5)) / 2
    y = x / goldenRatio

    plt.subplots(1, 1, figsize=(x, y))
    # NaN comparison types stringify to 'nan'; keep them out of the legend.
    typeLabels = data[typeCol].astype(str)
    hueOrder = sorted(set(typeLabels[typeLabels != 'nan']))
    sns.scatterplot(y=coefCol, x=weightCol, hue='Coefficient Type', hue_order=hueOrder, data=coefData)

    xlim = plt.xlim()
    plt.plot(xlim, [coef, coef], color='xkcd:black', label='Average Coefficient')
    plt.xlim(xlim)
    plt.xlabel('Coefficient Weights')
    plt.ylabel('Grouping Coefficients')
    plt.title(title)
    plt.legend(bbox_to_anchor=(1.02, 1))

    if saveFile is not None:
        plt.savefig('Plots/Bacon/{}.png'.format(saveFile), bbox_inches='tight')

    if show:
        plt.show()
    return coefData


# Plot Goodman-Bacon decomposition over time
def plotCoefsOverTime(data, coefStub, x=6, title=None, show=False, saveFile=None):
    """Scatter each 2x2 estimate against its treatment year.

    Markers are coloured by the control group and sized by the normalised
    weight. "Always treated" comparisons are placed at 1970 and labelled
    ``'Always'`` on the x axis. Rows whose treatment is ``'Within'`` are
    dropped.

    Parameters
    ----------
    data : pandas.DataFrame
        Reads the ``<coefStub>B``, ``S``, ``T``, ``C``, ``cgroup`` and
        ``GroupSize`` columns. The caller's frame is not modified.
    coefStub : str
        Column prefix, e.g. ``'bacReg'``.
    x : float, optional
        Figure width in inches. The height is ``x`` divided by the golden ratio.
    title : str, optional
        Plot title.
    show : bool, optional
        If True, call ``plt.show()`` after plotting.
    saveFile : str, optional
        If given, save the figure to ``Plots/Bacon/<saveFile>OverTime.png``.

    Returns
    -------
    pandas.DataFrame
        The plotted rows, sorted by treatment then control. The weights are
        normalised over the kept rows, and the treatment is converted to float.
    """
    coefCol = coefStub + 'B'
    weightCol = coefStub + 'S'
    treatCol = coefStub + 'T'
    controlCol = coefStub + 'C'
    typeCol = coefStub + 'cgroup'
    sizeCol = coefStub + 'GroupSize'

    # Copy first. The recoding below previously wrote 1970 into the caller's
    # treatment column.
    data = data.copy()

    # Give "always treated" comparisons a placeholder year so they can sit on
    # the time axis.
    alwaysTypes = ['Always treated vs timing', 'Always vs never treated']
    data.loc[data[typeCol].isin(alwaysTypes), treatCol] = 1970

    data = data.loc[data[treatCol] != 'Within', :].copy()
    data[treatCol] = data[treatCol].astype(float)

    # Renormalise the weights over the rows that remain.
    data[weightCol] = data[weightCol] / np.nansum(data[weightCol])

    coefData = data.loc[data[coefCol].notna(), :].copy()

    # Bump zero group sizes only. A frame-wide replace(0, 1) also turned
    # zero weights into 1, which made them the largest markers.
    sizeCols = [col for col in coefData.columns if str(col).endswith('GroupSize')]
    if sizeCols:
        coefData[sizeCols] = coefData[sizeCols].replace(0, 1)

    coef = np.nansum(data[coefCol] * data[weightCol])

    coefData['Number of States'] = data[sizeCol]
    coefData['Coefficient Type'] = data[typeCol]

    coefData = coefData.sort_values([treatCol, controlCol])

    goldenRatio = (1 + np.sqrt(5)) / 2
    y = x / goldenRatio

    plt.subplots(1, 1, figsize=(x, y))
    sns.scatterplot(y=coefCol, x=treatCol, hue=controlCol, size=weightCol,
                    data=coefData, legend=None)

    # 1970 is the placeholder year for the "always treated" comparisons.
    newXLocs = [1970, 1980, 1990, 2000, 2010, 2017]
    newXLabs = ['Always' if loc == 1970 else str(loc) for loc in newXLocs]
    plt.xticks(newXLocs, newXLabs)

    # Read the limits after the ticks are set (that can widen the view), then
    # pin them so the average line spans the whole axis.
    xlim = plt.xlim()
    plt.plot(xlim, [coef, coef], color='xkcd:black')
    plt.xlim(xlim)
    plt.xticks(rotation='vertical')
    plt.ylabel('Grouping Coefficients')
    plt.title(title)
    plt.xlabel('Treatment Year')

    if saveFile is not None:
        plt.savefig('Plots/Bacon/{}OverTime.png'.format(saveFile), bbox_inches='tight')

    if show:
        plt.show()
    return coefData


# Load the decomposition results once; every outcome's plots share them.
data = readData()

# Outcome column stub -> human-readable outcome name used in the plot titles.
stubs = OrderedDict([
    ('bacReg', 'Log Registrations'),
    ('bacVMT', 'Log VMT'),
    ('bacGas', 'Log Highway Gas Use'),
])

# Frame returned by plotCoefs for each stub, kept for later inspection.
coefData = OrderedDict()
for stub in stubs:
    title = stubs[stub]
    coefData[stub] = plotCoefs(data, stub,
                               title='{} Goodman-Bacon Decomposition'.format(title),
                               saveFile=stub)
    plotCoefsOverTime(data, stub,
                      title='{} Goodman-Bacon \nDecomposition Over Time'.format(title),
                      saveFile=stub)

t = coefData['bacReg']